import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import adept_envs # type: ignore
import gym
import math
import cv2
import numpy as np
from PIL import Image
import os
import torchvision.transforms as T
from vip import load_vip
import pickle
import time
torch.set_printoptions(edgeitems=10, linewidth=500)

# Module-level state shared by the classes and the script entry point below.
# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device is", device, flush=True)

# Encoder returned by load_vip(), switched to eval mode and moved to `device`.
vip = load_vip()
vip.eval()
vip = vip.to(device)

# Image preprocessing applied to every frame before it is passed to `vip`.
transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),  # ToTensor() divides by 255; callers multiply by 255.0 again before calling vip
])


class Robotic_Environment: # This will only be used for evaluating a trained model
    """Evaluation-time wrapper around the ``kitchen_relax-v1`` gym environment.

    The wrapper resets the simulator to a recorded initial state, steps it
    with policy actions while recording one video frame per step, and exposes
    the current state either as the raw observation vector or as image
    embeddings computed by the module-level ``vip`` encoder.

    NOTE(review): an earlier comment said this runs on pybullet; the code only
    shows ``gym.make('kitchen_relax-v1')`` and direct writes to
    ``env.sim.data`` -- confirm the actual simulator backend.
    """

    def __init__(self, video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval):
        """Create the environment and force it into the given initial state.

        Args:
            video_resolution: Size passed to ``cv2.resize`` and
                ``cv2.VideoWriter`` for the recorded frames (presumably
                ``(width, height)`` as OpenCV expects -- TODO confirm).
            gaussian_noise: When truthy, every entry of the initial state is
                scaled by ``1 + N(0, 0.03)``. The noise is multiplicative, so
                entries that are exactly zero stay zero.
            camera_number: Iterable of camera indices to embed in
                ``get_current_state``; only index 2 is supported.
            reset_information: ``(qpos, qvel)`` pair written into the
                simulator state.
            in_hand_eval: Accepted for interface compatibility; not used.
        """
        env = gym.make('kitchen_relax-v1')
        self.env = env.env
        self.camera_number = camera_number
        self.video_frames = [] # These are the frames of video saved for evaluation
        self.video_resolution = video_resolution
        self.env.reset()
        if gaussian_noise:
            mean = 0  # Mean of the Gaussian noise
            std_dev = 0.03  # Standard deviation of the Gaussian noise
            self.env.sim.data.qpos[:] = reset_information[0] * (1 + np.random.normal(mean, std_dev, reset_information[0].shape))
            self.env.sim.data.qvel[:] = reset_information[1] * (1 + np.random.normal(mean, std_dev, reset_information[1].shape))
        else:
            self.env.sim.data.qpos[:] = reset_information[0]
            self.env.sim.data.qvel[:] = reset_information[1]
        self.env.sim.forward() # The environment is setup

    def _render_rgb(self):
        """Render the current scene and return it as a numpy RGB array."""
        return np.array(self.env.render(mode='rgb_array'))

    def step(self, action):
        """Apply ``action`` to the environment and record the resulting frame.

        The rendered frame is converted to BGR (the channel order
        ``cv2.VideoWriter`` consumes), resized to ``self.video_resolution``
        and appended to ``self.video_frames``.
        """
        self.env.step(np.array(action)) # Execute some action
        bgr_array = cv2.cvtColor(self._render_rgb(), cv2.COLOR_RGB2BGR)
        bgr_array = cv2.resize(bgr_array, self.video_resolution)
        self.video_frames.append(bgr_array)

    def get_current_state(self, space): # This is the state in the format specified as input
        """Return the current state in the requested representation.

        Args:
            space: ``"joint_space"`` returns ``env._get_obs()`` as a list.
                ``"image_embedding"`` returns a list with one VIP embedding
                (itself a list) per supported entry of ``self.camera_number``.

        Raises:
            ValueError: If ``space`` is not one of the two supported values
                (previously this silently returned ``None``).
        """
        if space == "joint_space": # Get the joint state configuration
            return (self.env._get_obs()).tolist()
        if space == "image_embedding":
            img_state = []
            for camera_index in self.camera_number:
                if camera_index != 2:
                    # Only camera 2 can be rendered here; other indices contribute nothing.
                    print("Camera angle not available!")
                    continue
                rgb_array = self._render_rgb()
                preprocessed_image = transforms(Image.fromarray(rgb_array.astype(np.uint8))).reshape(-1, 3, 224, 224)
                preprocessed_image = preprocessed_image.to(device)
                with torch.no_grad():
                    # transforms scaled the pixels by 1/255; vip is fed the 0-255 range again.
                    subgoal_embedding = vip(preprocessed_image * 255.0)
                img_state.append(subgoal_embedding.cpu().tolist()[0])
            return img_state
        raise ValueError(f"Unknown state space {space!r}; expected 'joint_space' or 'image_embedding'")

    def save_video(self, video_filename, video_filename_in_hand):
        """Write every 4th recorded frame to ``video_filename`` as 30 fps mp4v.

        Args:
            video_filename: Output path of the video.
            video_filename_in_hand: Accepted for interface compatibility; no
                in-hand frames are recorded, so nothing is written to it.

        Raises:
            RuntimeError: If OpenCV cannot open the output file for writing.
        """
        video_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_out = cv2.VideoWriter(video_filename, video_fourcc, 30.0, self.video_resolution)
        if not video_out.isOpened():
            video_out.release()
            raise RuntimeError(f"Could not open video writer for {video_filename}")
        try:
            for frame in self.video_frames[::4]: # fast forward 4X
                video_out.write(frame)
        finally:
            video_out.release()

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    A table of shape ``(max_len, 1, d_model)`` is precomputed, with sines in
    the even feature indices and cosines in the odd ones. ``forward`` adds the
    first ``seq_len`` rows to the input and then applies dropout.

    Args:
        d_model: Feature size of each token. Both even and odd sizes work.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table supports.
    """

    def __init__(self, d_model, dropout, max_len=5000):  # Assuming 5000 is the maximum length of any trajectory
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # Shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # Shape: (ceil(d_model/2),)
        angles = position * div_term  # Shape: (max_len, ceil(d_model/2))
        pe = torch.zeros(max_len, 1, d_model)  # Shape: (max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)  # Apply sine to even indices
        # For odd d_model there is one odd index fewer than even indices, so the
        # last cosine column is dropped; for even d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])  # Apply cosine to odd indices
        self.register_buffer('pe', pe)  # Register as buffer to avoid updating during training

    def forward(self, x):
        """Add the encoding to ``x`` (indexed as ``(seq_len, batch, d_model)``) and apply dropout."""
        seq_len = x.shape[0]  # Get sequence length from input
        x = x + self.pe[:seq_len]  # Add positional encoding (broadcasting over batch)
        return self.dropout(x)  # Apply dropout for regularization

class CustomTransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a two-layer feed-forward net.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    ``dropout`` is forwarded to ``nn.MultiheadAttention``; the ``self.dropout``
    module is created but ``forward`` never applies it. Any ``activation``
    string other than ``"relu"`` selects GELU.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, activation="relu",dropout=0.0):
        super().__init__()
        # Sub-modules are registered in a fixed order so parameter
        # initialisation and state_dict layout stay stable.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src ):
        """Run ``src`` through attention and feed-forward; the output has the same shape."""
        # Attention sub-layer: query, key and value are all `src`; the weights are discarded.
        attended = self.self_attn(src, src, src)[0]
        normed = self.norm1(src + attended)
        # Feed-forward sub-layer, again with residual + norm.
        hidden = self.activation(self.linear1(normed))
        return self.norm2(normed + self.linear2(hidden))

class TransformerPolicy_Custom(nn.Module):  # This is transformer policy with customized encoder
    """Transformer policy that reads the action off a learned action token.

    The token sequence is: the image-embedding tokens, the goal embedding, the
    joint state projected to the token width by an MLP, and finally a
    trainable action token. After the encoder stack, the output at the action
    token position goes through a 2-layer MLP head to produce the action.
    """

    def __init__(self, output_dimension, nhead, num_encoder_layers, dim_feedforward, dropout, activation="relu"):
        super().__init__()
        token_dim = 1024  # Width of every token fed to the transformer
        self.joint_dim = 60  # Size of the joint-state vector consumed by joint_state_mlp
        self.output_dimension = output_dimension  # Size of the predicted action vector
        self.action_token = nn.Parameter(torch.randn(1, 1, token_dim))  # Trainable action token
        self.pos_encoder = PositionalEncoding(token_dim, dropout)

        encoder_layers = []
        for _ in range(num_encoder_layers):
            encoder_layers.append(
                CustomTransformerEncoderLayer(
                    d_model=token_dim,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    activation=activation,
                    dropout=dropout,
                )
            )
        self.transformer_encoder_layers = nn.ModuleList(encoder_layers)

        # Lifts the joint-state vector to a single token of width token_dim.
        self.joint_state_mlp = nn.Sequential(
            nn.Linear(self.joint_dim, 512),
            nn.ReLU(),
            nn.Linear(512, token_dim),
        )

        # 2-layer MLP output head applied to the action token's encoding.
        self.output_layer = nn.Sequential(
            nn.Linear(token_dim, 512),
            nn.ReLU(),
            nn.Linear(512, output_dimension),
        )

        self._initialize_weights()

    def forward(self, image_state, goal, joint_state, timestamp):
        """Predict an action vector for each batch element.

        ``image_state`` is concatenated as-is along dim 1, so it must already
        be ``(batch, n_tokens, 1024)``; ``goal`` and the projected
        ``joint_state`` each receive a token axis via ``unsqueeze(1)``.
        ``timestamp`` is accepted but not used by the network.
        """
        batch_size = image_state.shape[0]
        tokens = torch.cat(
            [
                image_state,
                goal.unsqueeze(1),
                self.joint_state_mlp(joint_state).unsqueeze(1),
                self.action_token.expand(batch_size, -1, -1),  # one action token per batch element
            ],
            dim=1,
        )  # Shape: (batch_size, seq_len + 1, 1024)

        # The encoder stack works sequence-first: (seq_len + 1, batch_size, 1024)
        tokens = self.pos_encoder(tokens.transpose(0, 1))
        for layer in self.transformer_encoder_layers:
            tokens = layer(tokens)

        # The action token was appended last, so its encoding is the final sequence position.
        return self.output_layer(tokens[-1])

    def _initialize_weights(self):
        """Xavier-init linear/attention weights, zero the biases, reset LayerNorm, re-draw the action token."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)  # Xavier initialization for linear layers
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.MultiheadAttention):
                nn.init.xavier_uniform_(module.in_proj_weight)
                if module.in_proj_bias is not None:
                    nn.init.zeros_(module.in_proj_bias)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Overwrites the randn() draw from __init__ with a small-variance normal.
        nn.init.normal_(self.action_token, mean=0.0, std=0.02)


class TrajectoryDataset(Dataset): # Dataset for behavioural cloning
    """Per-timestep behaviour-cloning samples built from recorded trajectories.

    Every trajectory directory is expected to contain ``data.pkl`` (with
    ``observations`` and ``actions`` arrays) plus one ``camera_<n>.avi`` per
    entry of ``camera_number``. All videos are embedded with the module-level
    ``vip`` encoder at construction time, so ``__getitem__`` only converts
    cached Python lists into tensors.

    Args:
        Trajectory_directories: Directory names relative to ``base_directory``.
        base_directory: Root directory holding the trajectory directories.
        subgoal_directory_path: Stored on the instance; not used by this class.
        camera_number: Camera indices whose videos are embedded. The first
            entry supplies the goal embedding.
        action_chunking: Number of consecutive actions concatenated into each
            sample's target (1 means single-step prediction).
    """

    def __init__(self, Trajectory_directories, base_directory , subgoal_directory_path, camera_number, action_chunking):
        self.Trajectory_directories = Trajectory_directories # List of all the directories
        self.base_directory = base_directory
        self.subgoal_directory_path = subgoal_directory_path
        self.camera_number = camera_number
        self.action_chunking = action_chunking
        self.trajectories = self._load_trajectories()

    @staticmethod
    def _embed_frame(frame):
        """Return the VIP embedding of one BGR video frame as a flat Python list."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
        preprocessed_image = preprocessed_image.to(device)
        with torch.no_grad():
            # transforms scaled the pixels by 1/255; vip is fed the 0-255 range again.
            subgoal_embedding = vip(preprocessed_image * 255.0)
        return subgoal_embedding.cpu().tolist()[0]

    def _read_csv(self, file_path, directory):
        """Load one trajectory (a pickle file, despite the method name) into samples.

        Returns:
            A list with one entry per timestep, each of the form
            ``[image_state, goal, joint_state, action, timestamp]``.

        Raises:
            ValueError: If the trajectory has no timesteps.
            RuntimeError: If a camera video has fewer readable frames than the
                trajectory has timesteps.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load trusted trajectory files.
        with open(file_path, 'rb') as f: # Read the pickle file
            data_dict = pickle.load(f)

        observations = data_dict['observations']  # e.g. shape (244, 60) per the original author's note
        actions = data_dict['actions']  # e.g. shape (244, 9) per the original author's note
        if observations.shape[0] == 0:
            raise ValueError(f"Trajectory {file_path} contains no timesteps")

        # Row layout: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns.
        # The fixed slice indices below assume exactly 60 observation and 9 action columns.
        data = [
            list(observations[i]) + [0., 0., 0.] + [i] + list(actions[i])
            for i in range(observations.shape[0])
        ]

        # One video per camera angle; each frame's embedding is appended to the matching row.
        video_paths = [f"{self.base_directory}/{directory}/camera_{camera_angle}.avi" for camera_angle in self.camera_number]
        for video_path in video_paths:
            cap = cv2.VideoCapture(video_path)
            try:
                for i, row in enumerate(data):
                    ret, frame = cap.read()  # ret is a boolean indicating success, frame is the image
                    if not ret:
                        raise RuntimeError(
                            f"Could not read frame {i} of {len(data)} from {video_path}"
                        )
                    row.extend(self._embed_frame(frame))
            finally:
                cap.release()

        # Each row is now 60(joint) + 3(task) + 1(time) + 9(action) + 1024(image embedding) * number of camera angles
        goal = data[-1][73:1097] # Final frame's embedding from the first listed camera is the goal
        zero_action = [0.] * 9  # padding used when an action chunk runs past the end of the trajectory
        output = [] # (image_state, goal, joint_state, action, timestamp) per timestep
        for i, row in enumerate(data):
            embeddings = row[73:]
            image_state = [embeddings[j:j + 1024] for j in range(0, len(embeddings), 1024)]
            action = row[64:73]  # slicing copies, so extending `action` never mutates `row`
            for j in range(i + 1, i + self.action_chunking):
                action += data[j][64:73] if j < len(data) else zero_action
            output.append([image_state, goal, row[0:60], action, i])
        return output

    def _load_trajectories(self):
        """Concatenate the per-timestep samples of every trajectory directory."""
        trajectories = []
        for directory in self.Trajectory_directories:
            file_path = f"{self.base_directory}/{directory}/data.pkl"
            trajectories.extend(self._read_csv(file_path, directory))
        return trajectories

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Return ``(image_state, goal, joint_state, action, timestamp)`` as float32 tensors."""
        image_state, goal, joint_state, action, timestamp = self.trajectories[idx]
        return (
            torch.tensor(image_state, dtype=torch.float32),
            torch.tensor(goal, dtype=torch.float32),
            torch.tensor(joint_state, dtype=torch.float32),
            torch.tensor(action, dtype=torch.float32),
            torch.tensor(timestamp, dtype=torch.float32),
        )

def find_largest_number(file_path):
    """Return the integer that starts the last non-blank line of ``file_path``.

    The file is expected to be an append-only log whose lines each begin with
    a run number, so the last entry carries the largest number. Trailing blank
    lines are ignored, and a file with no entries yields 0 so that the first
    run becomes number 1.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the first word of the last entry is not an integer.
    """
    with open(file_path, 'r') as file:
        entries = [line for line in file if line.strip()]
    if not entries:
        return 0
    return int(entries[-1].split()[0])

if __name__ == '__main__':
    # Script entry point: optionally train the policy, then optionally roll it
    # out in the simulator and save one evaluation video per trajectory.

    # Parameters
    train = True  # run the supervised training stage
    evaluate = True  # run the rollout stage (named so the builtin `eval` is not shadowed)
    output_dimension = 9 # Per-step action size; TrajectoryDataset slices 9 action columns per timestep

    # Training demonstrations: "<task>.<demo>" for tasks 1-25, demos 1-5.
    Trajectory_directories = [f"{task}.{demo}" for task in range(1, 26) for demo in range(1, 6)]
    # Evaluation trajectories: tasks 1-25, demos 1-10. '8.6' is skipped because it was
    # absent from the original hand-written list -- TODO confirm that is intentional.
    list_of_subgoals_directory_to_eval = [
        f"{task}.{demo}"
        for task in range(1, 26)
        for demo in range(1, 11)
        if (task, demo) != (8, 6)
    ]
    total_number_of_iterations = 1 # number of iterations per task (helpful with gaussian noise)
    num_epochs = 1000 # number of epochs on the training dataset
    lr = 0.0003 # learning rate
    dropout = 0.1
    nhead = 8 # number of attention heads. state dimension must be divisible by attention heads
    num_encoder_layers = 4 # number of encoder layers
    dim_feedforward = 1024 # dimension of the feedforward network inside the transformer layers
    gaussian_noise = False # gaussian noise on the start state
    action_chunking = 10 # Action chunking = 1 means only 1 step prediction
    temporal_ensemble = 0 # Exponential decay rate of the weights used to combine overlapping action chunks
    camera_number = [2] # Camera angles fed to the policy. Keep camera 2 first because it supplies the final goal frame. For franka kitchen we only have camera 2
    in_hand_eval = False # Get in hand camera video or not

    output_dimension *= action_chunking  # the model predicts a whole chunk of actions at once
    subgoal_directory_path = "decomposed_frames/mininterval_18/divisions_1/gamma_0.08/camera_2" # Goal is always camera 2
    saving_formatter = str(find_largest_number("./Parameter_mappings.txt")+1)

    # Record which hyper-parameters belong to this run number.
    with open('./Parameter_mappings.txt', 'a') as file:
        file.write(f'{saving_formatter}        : num_epochs_{num_epochs}_lr_{lr}_dropout_{dropout}_nhead_{nhead}_num_encoder_layers_{num_encoder_layers}_dim_feedforward_{dim_feedforward}_camera_{camera_number}_gaussian_noise_{gaussian_noise}_action_chunking_{action_chunking}_temporal_ensemble_{temporal_ensemble}_Training_directory_{Trajectory_directories}\n')  # Add a newline character to separate lines
    model_dump_file_path = f"./Trained_Models/{saving_formatter}.pth"
    base_directory = "./../../Data_Franka_Kitchen"
    model = TransformerPolicy_Custom(output_dimension, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)
    print(model, flush=True)
    total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters in the neural network is: ", total_params, flush=True)

    if train:
        # Create the checkpoint directory up front so a missing folder fails
        # now rather than after the whole training run.
        os.makedirs(os.path.dirname(model_dump_file_path), exist_ok=True)

        train_start_time = time.time()
        trajectory_dataset = TrajectoryDataset(Trajectory_directories, base_directory , subgoal_directory_path, camera_number, action_chunking)
        data_loader = DataLoader(trajectory_dataset, batch_size=64, shuffle=True)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) # cosine decay of learning rate
        loss_function = nn.MSELoss()
        # Supervised Learning Loop
        for epoch in range(num_epochs):
            model.train()  # Set the model to training mode
            running_loss = 0.0  # Initialize running loss for the epoch
            num_batches = 0     # Initialize batch counter
            for image_state, goal, joint_state, action, timestamp in data_loader:
                image_state = image_state.to(device)
                goal = goal.to(device)
                joint_state = joint_state.to(device)
                action = action.to(device)
                timestamp = timestamp.to(device)

                optimizer.zero_grad()
                predicted_actions = model(image_state, goal, joint_state, timestamp)
                loss = loss_function(predicted_actions, action)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()  # Accumulate loss
                num_batches += 1             # Increment batch counter

            scheduler.step()

            if epoch % 50 == 0: # Print loss every 50 epochs
                current_lr = optimizer.param_groups[0]['lr'] # Learning rate after this epoch's scheduler step
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches}, Learning Rate: {current_lr}", flush=True)

        torch.save(model.state_dict(), model_dump_file_path)
        print(f"Model saved to {model_dump_file_path}", flush=True)

        print("Time taken to train the model is ", (time.time() - train_start_time)/3600 , " hrs")

    if evaluate:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(model_dump_file_path, map_location=device))
        model.eval()  # Set the model to evaluation mode
        length_of_trajectories_during_inference = {}
        reset_info_of_trajectories_during_inference = {} # information about the trajectory to infer
        for directory in list_of_subgoals_directory_to_eval:
            file_path = f"{base_directory}/{directory}/data.pkl" # pkl file path
            # SECURITY: pickle.load can execute arbitrary code; only load trusted trajectory files.
            with open(file_path, 'rb') as f: # Read the pickle file
                data_dict = pickle.load(f)
            length_of_trajectories_during_inference[directory] = data_dict['observations'].shape[0]
            reset_info_of_trajectories_during_inference[directory] = (data_dict['init_qpos'] , data_dict['init_qvel'])

        def robot_inference(robot_env, directory_for_subgoals):
            """Roll the trained policy out in ``robot_env`` towards the trajectory's final goal frame."""
            max_steps = length_of_trajectories_during_inference[directory_for_subgoals]
            subgoals_directory = f"{base_directory}/{directory_for_subgoals}/{subgoal_directory_path}"
            # Subgoal images are named "<frame index>.png"; the largest index is the final goal.
            subgoal_indices = [
                int(f.replace('.png', ''))
                for f in os.listdir(subgoals_directory)
                if f.endswith('.png')
            ]
            if not subgoal_indices:
                raise FileNotFoundError(f"No subgoal .png files found in {subgoals_directory}")
            goal_index = max(subgoal_indices)

            video_path = f"{base_directory}/{directory_for_subgoals}/camera_{2}.avi" # This is path of video for the final goal frame
            cap = cv2.VideoCapture(video_path)
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, goal_index)
                ret, frame = cap.read()
            finally:
                cap.release()
            if not ret:
                raise RuntimeError(f"Could not read goal frame {goal_index} from {video_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
            preprocessed_image = preprocessed_image.to(device)
            with torch.no_grad():
                subgoal_embedding = vip(preprocessed_image * 255.0)
            goal = subgoal_embedding.cpu().tolist()[0]
            # The goal is fixed for the whole rollout, so its tensor is built once.
            goal_tensor = torch.tensor([goal], dtype=torch.float32).to(device)

            # action_buffer[t] collects every prediction made for timestep t, oldest first.
            action_buffer = [[] for _ in range(max_steps + action_chunking)]
            for i in range(max_steps):  # i is the current timestamp
                timestamp = torch.tensor([i]).to(device)
                if i % 100 == 0:
                    print(f"Timestamp: {i}/{max_steps}", flush=True)
                img_state = robot_env.get_current_state("image_embedding")
                img_state_tensor = torch.tensor([img_state], dtype=torch.float32).to(device)
                joint_state = robot_env.get_current_state("joint_space")
                joint_state_tensor = torch.tensor([joint_state], dtype=torch.float32).to(device)
                with torch.no_grad():
                    action = model(img_state_tensor, goal_tensor, joint_state_tensor, timestamp)  # Model predicts action chunks
                action = action.cpu().numpy().flatten()  # Convert to numpy and flatten
                action = action.reshape(action_chunking, output_dimension // action_chunking)  # Reshape into action chunks

                for j in range(action_chunking): # Add the action chunks to the buffer
                    action_buffer[i + j].append(action[j])
                # Temporal ensemble: weighted average of all predictions for this timestep.
                weights = np.exp(-temporal_ensemble * np.arange(len(action_buffer[i])))
                weights /= weights.sum()  # Normalize weights
                current_action = np.sum([w * a for w, a in zip(weights, action_buffer[i])], axis=0)
                current_action = current_action.tolist()  # Convert to list before passing to `step`
                robot_env.step(current_action)

        video_resolution = (224, 224) # This is for resolution for evaluation videos
        for iteration_number in range(1, total_number_of_iterations + 1): # Number of times to evaluate a single trajectory, to get the evaluation metrics
            for directory_for_subgoals in list_of_subgoals_directory_to_eval: # These are all the trajectories to get subgoals from and evaluate
                reset_information = reset_info_of_trajectories_during_inference[directory_for_subgoals]
                robot_env = Robotic_Environment(video_resolution, gaussian_noise, camera_number , reset_information, in_hand_eval)

                print(f"Evaluating subgoals from {directory_for_subgoals}, iteration number {iteration_number}...", flush=True)
                robot_inference(robot_env, directory_for_subgoals)
                video_path = f"./Evaluation/{saving_formatter}/subgoals_{directory_for_subgoals}"
                os.makedirs(video_path, exist_ok=True) # Directory to save Evaluation Videos
                video_filename = f"{video_path}/{iteration_number}.mp4"
                video_filename_in_hand = f"{video_path}/{iteration_number}_in_hand.mp4"
                robot_env.save_video(video_filename, video_filename_in_hand)